from typing import Any
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np 
from LLM_utils import sample_sentences, calculate_sequence_probability
import wandb 
import time 
from pdb import set_trace


class LLM_MetropolisHastings:
    def __init__(self, reference_LLM, reference_tokenizer, initial_sequence, 
                proposal_prompt, target_distribution, 
                max_length = 100, top_k=50,top_p=0.95,
                normalization_strategy = "None",
                use_log_probabilities = False,
                ignore_proposal_distribution = False,
                wandb_logging = False):
        """
        Initialize the Metropolis-Hastings sampler.

        reference_LLM: Language model to use as reference distribution
        reference_tokenizer: Tokenizer for reference_LLM
        initial_sequence: Initial sequence (string) to start sampling from
        proposal_prompt: Prompt that is prepended to the current sequence when drawing a proposal
        target_distribution: Object exposing ``evaluate(sequence=..., current_sequence=..., encoded_sequence=...)``
            (returning a dictionary of probability statistics) and an ``advertiser_prompts`` list.
        max_length, top_k, top_p: Generation settings forwarded to ``sample_sentences`` in ``step``
        normalization_strategy: One of "token", "character", "byte" or "None"; selects the length
            normalization used in the acceptance probability
        use_log_probabilities: Not supported; ``step`` raises NotImplementedError when this is set
        ignore_proposal_distribution: Flag read by ``step``; intended to drop the proposal-distribution
            ratio from the acceptance probability
        wandb_logging: If True, register the per-round metrics with wandb and log every step
        """
        self.target_distribution = target_distribution
        self.proposal_prompt = proposal_prompt
        self.reference_LLM = reference_LLM
        self.reference_tokenizer = reference_tokenizer
        self.wandb_logging = wandb_logging
        self.max_length = max_length
        self.top_k = top_k
        self.top_p = top_p
        self.normalization_strategy = normalization_strategy
        self.use_log_probabilities = use_log_probabilities
        self.ignore_proposal_distribution = ignore_proposal_distribution

        self.current_sequence = initial_sequence
        # The sampler works on decoded strings, so the sequence is passed un-encoded.
        # current_sequence=None: there is no previous state to condition on yet.
        self.current_sequence_probability_dict = self.target_distribution.evaluate(
            sequence=self.current_sequence,
            current_sequence=None,
            encoded_sequence=False,
        )
        self.current_round = 1

        # One accumulator slot per advertiser.
        n_advertisers = len(self.target_distribution.advertiser_prompts)
        self.advertiser_cumulative_Myerson_payments = np.zeros(n_advertisers, dtype=float)
        self.advertiser_cumulative_GSP_payments = np.zeros(n_advertisers, dtype=float)
        # Kept for backward compatibility; step() accumulates into the _MP / _GSP arrays below.
        self.advertiser_cumulative_indirect_utility = np.zeros(n_advertisers, dtype=float)
        self.advertiser_cumulative_indirect_utility_MP = np.zeros(n_advertisers, dtype=float)
        self.advertiser_cumulative_indirect_utility_GSP = np.zeros(n_advertisers, dtype=float)
        self.reference_llm_cumulative_Myerson_payments = 0
        self.reference_llm_cumulative_GSP_payments = 0

        if wandb_logging:
            # NOTE(review): wandb.init is not called here; presumably the caller has
            # already started the run before constructing the sampler -- confirm.
            wandb.define_metric("round")

            # Every metric below is plotted against the "round" counter.
            round_metrics = [
                "current sequence",
                "proposal sequence",
                "current probability",
                "current log probability",
                "proposal probability",
                "proposal log probability",
                "proposal reference llm probability",
                "proposal reference llm log probability",
                "advertiser welfare",
                "number of tokens",
                "number of bytes",
                "number of characters",
                "acceptance probability",
                "accepted proposal",
                "current sequence character normalized log probability",
                "current sequence token normalized log probability",
                "current sequence character normalized probability",
                "current sequence token normalized probability",
                "proposal generation time",
                "proposal evaluation time",
                "proposal cond current prob",
                "current cond proposal prob",
                "proposal distribution fraction",
                "probability fraction",
                "normalization fraction",
                "probability log fraction",
            ]
            round_metrics.extend(f"myerson case {case} percentage" for case in range(1, 7))

            for metric_name in round_metrics:
                wandb.define_metric(metric_name, step_metric="round")


    
    def step(self):
        """
        Perform one step of the Metropolis-Hastings algorithm.

        A proposal is sampled from the reference LLM conditioned on
        ``proposal_prompt + current_sequence`` (re-sampling until it differs from the
        current sequence), scored by the target distribution, and accepted with
        probability ``min(1, A)``, where ``A`` is the product of the target-probability
        ratio, the length-normalization ratio and the proposal-distribution ratio.
        Myerson and GSP payments and the indirect utilities of the round are computed
        and accumulated, and all statistics are logged to wandb when ``wandb_logging``
        is enabled.

        :return: The current sequence after the step (the proposal if it was accepted,
                 otherwise the unchanged previous sequence).
        :raises NotImplementedError: If ``use_log_probabilities`` is set.
        :raises ValueError: If ``normalization_strategy`` is not one of
                            "token", "character", "byte" or "None".
        """
        # Fail fast on unsupported configuration, before paying for LLM sampling.
        if self.use_log_probabilities:
            # NOTE: A ratio of log probabilities makes no sense in terms of theory, thus turned off.
            raise NotImplementedError

        normalization_keys = {
            "token": 'number_of_tokens',
            "character": 'number_of_characters',
            "byte": 'number_of_bytes',
        }
        if self.normalization_strategy != "None" and self.normalization_strategy not in normalization_keys:
            raise ValueError(f'Normalization strategy {self.normalization_strategy} not recognized')

        current_prob_dict = self.current_sequence_probability_dict
        proposal_distribution_sequence = self.proposal_prompt + self.current_sequence

        # Draw a sample from the proposal distribution and make sure it differs from
        # the current sequence.
        # NOTE(review): this loop has no retry cap; it never terminates if the sampler
        # keeps returning the current sequence.
        samples_required = 0
        start = time.time()
        while True:
            generated = sample_sentences(model = self.reference_LLM, tokenizer = self.reference_tokenizer, input_sequence = proposal_distribution_sequence,
                                         max_length = self.max_length, top_k= self.top_k,top_p=self.top_p,num_return_sequences=1, print_output= False)
            proposal = self.reference_tokenizer.decode(generated[0], skip_special_tokens=True)
            samples_required += 1
            if proposal != self.current_sequence:
                break
        sample_time = time.time() - start

        # Score the (decoded) proposal under the target distribution.
        start = time.time()
        proposal_prob_dict = self.target_distribution.evaluate(
            sequence=proposal,
            current_sequence=self.current_sequence,
            encoded_sequence=False,
        )
        eval_time = time.time() - start

        # --- Compute the acceptance probability --- #
        proposal_cond_current_log_prob = proposal_prob_dict['proposal_cond_current_log_prob']
        current_cond_proposal_log_prob = proposal_prob_dict['current_cond_proposal_log_prob']
        if self.ignore_proposal_distribution:
            # Treat g(y'|y) and g(y|y') as 1, i.e. log-probability 0, so that the
            # proposal-distribution ratio below becomes exactly 1.
            # NOTE(review): calculate_payments still reads the real values from the dictionary.
            proposal_cond_current_log_prob = 0.0
            current_cond_proposal_log_prob = 0.0

        base_proposal_log_probability = proposal_prob_dict['target_log_prob']
        base_current_log_probability = current_prob_dict['target_log_prob']

        if self.normalization_strategy == "None":
            normalization_factor_proposal = 1
            normalization_factor_current = 1
        else:
            length_key = normalization_keys[self.normalization_strategy]
            normalization_factor_proposal = proposal_prob_dict[length_key]
            normalization_factor_current = current_prob_dict[length_key]

        # Differences of log probabilities are exponentiated at the end for more stability.
        probability_fraction = np.exp(base_proposal_log_probability - base_current_log_probability)
        normalization_fraction = normalization_factor_current / normalization_factor_proposal
        proposal_distribution_fraction = np.exp(current_cond_proposal_log_prob - proposal_cond_current_log_prob)

        acceptance_probability = probability_fraction * normalization_fraction * proposal_distribution_fraction
        acceptance_probability = min(1, acceptance_probability)
        print(f'Acceptance probability: {acceptance_probability}, proposal probability fraction: {probability_fraction}, normalization fraction: {normalization_fraction}, proposal distribution fraction: {proposal_distribution_fraction}')

        wandb_dict = {
            'round': self.current_round,
            'current sequence': self.current_sequence,
            'proposal sequence': proposal,
            'current probability': current_prob_dict['target_prob'],
            'current log probability': current_prob_dict['target_log_prob'],
            'current sequence character normalized log probability': current_prob_dict['target_log_prob'] / len(self.current_sequence),
            'current sequence token normalized log probability': current_prob_dict['target_log_prob'] / current_prob_dict['number_of_tokens'],
            'current sequence character normalized probability': current_prob_dict['target_prob'] / len(self.current_sequence),
            'current sequence token normalized probability': current_prob_dict['target_prob'] / current_prob_dict['number_of_tokens'],
            'proposal probability':  proposal_prob_dict['target_prob'],
            'proposal log probability': proposal_prob_dict['target_log_prob'],
            'acceptance probability': acceptance_probability,
            'proposal reference llm probability': proposal_prob_dict['reference_llm_prob'],
            'proposal reference llm log probability': proposal_prob_dict['reference_llm_log_prob'],
            'advertiser welfare': current_prob_dict['advertiser_welfare'],
            'number of tokens': current_prob_dict['number_of_tokens'],
            'number of bytes': current_prob_dict['number_of_bytes'],
            'number of characters': current_prob_dict['number_of_characters'],
            'proposal generation time': sample_time,
            'proposal evaluation time': eval_time,
            'samples to draw non-identity proposal': samples_required,
            'proposal cond current prob': proposal_prob_dict['proposal_cond_current_prob'],
            'current cond proposal prob': proposal_prob_dict['current_cond_proposal_prob'],
            'proposal distribution fraction': proposal_distribution_fraction,
            'probability fraction': probability_fraction,
            'normalization fraction': normalization_fraction,
        }

        u_variable = np.random.rand()
        accepted_proposal = u_variable < acceptance_probability
        wandb_dict['accepted proposal'] = 1 if accepted_proposal else 0

        # Calculate payments and utilities
        (myerson_payments_advertisers, myerson_payment_reference_llm, myerson_case_counter,
         gsp_payments_advertisers, gsp_payment_reference_llm) = self.calculate_payments(
            current_prob_dict, proposal_prob_dict, u_variable, accepted_proposal)

        # update cumulative payments
        self.advertiser_cumulative_Myerson_payments += myerson_payments_advertisers
        self.advertiser_cumulative_GSP_payments += gsp_payments_advertisers
        self.reference_llm_cumulative_Myerson_payments += myerson_payment_reference_llm
        self.reference_llm_cumulative_GSP_payments += gsp_payment_reference_llm

        # A "bid" is the change in (weighted) reward when moving from the current sequence to the proposal.
        advertiser_bids = proposal_prob_dict['advertiser_bid_weighted_log_probs'] - current_prob_dict['advertiser_bid_weighted_log_probs']
        reference_llm_bid = proposal_prob_dict['tau_weighted_reference_llm_log_prob'] - current_prob_dict['tau_weighted_reference_llm_log_prob']

        n_advertisers = len(myerson_payments_advertisers)

        # Create the indirect-utility accumulators on first use, so the step also works
        # on instances whose constructor did not allocate them.
        for accumulator_name in ('advertiser_cumulative_indirect_utility_MP', 'advertiser_cumulative_indirect_utility_GSP'):
            if not hasattr(self, accumulator_name):
                setattr(self, accumulator_name, np.zeros(n_advertisers, dtype=float))

        indirect_utilities_MP = advertiser_bids * acceptance_probability - myerson_payments_advertisers
        indirect_utilities_GSP = advertiser_bids * acceptance_probability - gsp_payments_advertisers
        self.advertiser_cumulative_indirect_utility_MP += indirect_utilities_MP
        self.advertiser_cumulative_indirect_utility_GSP += indirect_utilities_GSP

        for i in range(n_advertisers):
            wandb_dict[f'myerson payments advertiser {i}'] = myerson_payments_advertisers[i]
            wandb_dict[f'gsp payments advertiser {i}'] = gsp_payments_advertisers[i]
            wandb_dict[f'cumulative myerson payments advertiser {i}'] = self.advertiser_cumulative_Myerson_payments[i]
            wandb_dict[f'cumulative gsp payments advertiser {i}'] = self.advertiser_cumulative_GSP_payments[i]
            wandb_dict[f'indirect utility MP advertiser {i}'] = indirect_utilities_MP[i]
            wandb_dict[f'indirect utility GSP advertiser {i}'] = indirect_utilities_GSP[i]
            wandb_dict[f'cumulative indirect utility MP advertiser {i}'] = self.advertiser_cumulative_indirect_utility_MP[i]
            wandb_dict[f'cumulative indirect utility GSP advertiser {i}'] = self.advertiser_cumulative_indirect_utility_GSP[i]

        # add Myerson payment cases to wandb_dict (share of bidders that fell into each case)
        total_cases = np.sum(myerson_case_counter)
        for case_number, case_count in enumerate(myerson_case_counter, start=1):
            wandb_dict[f'myerson case {case_number} percentage'] = case_count / total_cases

        cumulative_myerson_advertisers = np.sum(self.advertiser_cumulative_Myerson_payments)
        cumulative_gsp_advertisers = np.sum(self.advertiser_cumulative_GSP_payments)
        wandb_dict['myerson payments reference llm'] = myerson_payment_reference_llm
        wandb_dict['gsp payments reference llm'] = gsp_payment_reference_llm
        wandb_dict['cumulative myerson payments reference llm'] = self.reference_llm_cumulative_Myerson_payments
        wandb_dict['cumulative gsp payments reference llm'] = self.reference_llm_cumulative_GSP_payments
        wandb_dict['cumulative myerson payments'] = cumulative_myerson_advertisers + self.reference_llm_cumulative_Myerson_payments
        wandb_dict['cumulative gsp payments'] = cumulative_gsp_advertisers + self.reference_llm_cumulative_GSP_payments
        wandb_dict['cumulative myerson payments advertisers'] = cumulative_myerson_advertisers
        wandb_dict['cumulative gsp payments advertisers'] = cumulative_gsp_advertisers

        indirect_utility_MP_reference_llm = reference_llm_bid * acceptance_probability - myerson_payment_reference_llm
        indirect_utility_GSP_reference_llm = reference_llm_bid * acceptance_probability - gsp_payment_reference_llm
        wandb_dict['indirect utility MP reference llm'] = indirect_utility_MP_reference_llm
        wandb_dict['indirect utility GSP reference llm'] = indirect_utility_GSP_reference_llm

        total_myerson_advertisers = np.sum(myerson_payments_advertisers)
        total_gsp_advertisers = np.sum(gsp_payments_advertisers)
        wandb_dict['total indirect utility MP'] = np.sum(indirect_utilities_MP) + indirect_utility_MP_reference_llm
        wandb_dict['total indirect utility GSP'] = np.sum(indirect_utilities_GSP) + indirect_utility_GSP_reference_llm
        wandb_dict['total payments MP'] = total_myerson_advertisers + myerson_payment_reference_llm
        wandb_dict['total payments GSP'] = total_gsp_advertisers + gsp_payment_reference_llm
        wandb_dict['total payments MP advertisers'] = total_myerson_advertisers
        wandb_dict['total payments GSP advertisers'] = total_gsp_advertisers

        if accepted_proposal:
            print("Accepting proposal. Acceptance prob: ", acceptance_probability , "new log prob: ", proposal_prob_dict['target_log_prob'], " old probability log prob: ", self.current_sequence_probability_dict['target_log_prob'])
            self.current_sequence = proposal
            self.current_sequence_probability_dict = proposal_prob_dict

        self.current_round += 1

        if self.wandb_logging:
            wandb.log(wandb_dict)

        return self.current_sequence
    
    
    def calculate_payments(self, current_sequence_probability_dict, proposal_sequence_probability_dict, u_variable, accepted_proposal):
        """
        Calculate Myerson and GSP payments for each advertiser.

        """

        advertiser_rewards_proposal = proposal_sequence_probability_dict['advertiser_bid_weighted_log_probs']
        advertiser_rewards_current = current_sequence_probability_dict['advertiser_bid_weighted_log_probs']
        advertiser_bids = advertiser_rewards_proposal - advertiser_rewards_current

        reference_llm_reward_proposal= proposal_sequence_probability_dict['tau_weighted_reference_llm_log_prob'] 
        reference_llm_reward_current = current_sequence_probability_dict['tau_weighted_reference_llm_log_prob']
        reference_llm_bid = reference_llm_reward_proposal - reference_llm_reward_current

        proposal_cond_current_log_prob = proposal_sequence_probability_dict['proposal_cond_current_log_prob'] 
        current_cond_proposal_log_prob = proposal_sequence_probability_dict['current_cond_proposal_log_prob'] 
        proposal_fraction = np.exp(current_cond_proposal_log_prob - proposal_cond_current_log_prob) # this is the c variable in the paper 

        tau = proposal_sequence_probability_dict['tau']

        # --- Calculate Myerson payments --- #
        # NOTE: Here the reference LLM is included in the advertisers 
        all_bids = np.concatenate(([reference_llm_bid], advertiser_bids))
        myerson_payments_advertisers = []
        myerson_case_counter = np.zeros(6)
        for i in range(len(all_bids)):
            current_advertiser_bid = all_bids[i]
            
            # Isolate the bids of the other advertisers
            advertiser_bids_rest = all_bids.copy()
            advertiser_bids_rest[i] = 0
            other_bids_sum= np.sum(advertiser_bids_rest)

            if current_advertiser_bid > 0: 

                if (proposal_fraction * np.exp(current_advertiser_bid + other_bids_sum) ) <= 1: # Case 1 in the paper 
                    print("Myerson payment case 1")
                    myerson_payment = proposal_fraction * np.exp(other_bids_sum) * ( (current_advertiser_bid  - 1) * np.exp(current_advertiser_bid) + 1  )
                    myerson_case_counter[0] += 1

                elif (proposal_fraction *  np.exp(current_advertiser_bid + other_bids_sum) >= 1) and (proposal_fraction * np.exp(other_bids_sum) <=1): # Case 2 in the paper
                    print("Myerson payment case 2")
                    myerson_payment = -other_bids_sum - np.log(proposal_fraction) - 1 + proposal_fraction * np.exp(other_bids_sum)
                    myerson_case_counter[1] += 1

                elif (proposal_fraction *  np.exp(current_advertiser_bid + other_bids_sum) >= 1) and (proposal_fraction * np.exp(other_bids_sum) >=1): # Case 3 in the paper
                    print("Myerson payment case 3")
                    myerson_payment = 0 
                    myerson_case_counter[2] += 1

                else: 
                    raise ValueError('Something went wrong in the Myerson payment calculation, current_advertiser_bid: ', current_advertiser_bid, ' other_bids_sum: ', other_bids_sum)
                
            else: 
                if proposal_fraction * np.exp(other_bids_sum) <= 1: # Case 4 in the paper
                    print("Myerson payment case 4")
                    myerson_payment = proposal_fraction * np.exp(other_bids_sum) * ( (current_advertiser_bid  - 1) * np.exp(current_advertiser_bid) + 1  )
                    myerson_case_counter[3] += 1

                elif (proposal_fraction * np.exp(other_bids_sum) >= 1) and (proposal_fraction * np.exp(other_bids_sum + current_advertiser_bid) <= 1): # Case 5 in the paper
                    print("Myerson payment case 5")
                    myerson_payment = proposal_fraction * (current_advertiser_bid - 1) * np.exp(other_bids_sum + current_advertiser_bid) + other_bids_sum + np.log(proposal_fraction) + 1
                    myerson_case_counter[4] += 1

                elif (proposal_fraction * np.exp(other_bids_sum) >= 1) and (proposal_fraction * np.exp(other_bids_sum + current_advertiser_bid) >= 1): # Case 6 in the paper
                    print("Myerson payment case 6")
                    myerson_payment = 0
                    myerson_case_counter[5] += 1

                else:
                    raise ValueError('Something went wrong in the Myerson payment calculation, current_advertiser_bid: ', current_advertiser_bid, ' other_bids_sum: ', other_bids_sum)
                
            if i == 0:
                myerson_payment_reference_llm = myerson_payment
            else:
                myerson_payments_advertisers.append(myerson_payment)   

        myerson_payments_advertisers = np.array(myerson_payments_advertisers)

        # --- Calculate GSP payments --- #
        if accepted_proposal:
            log_u_variable = np.log(u_variable)
            proposal_distribution_log = proposal_cond_current_log_prob - current_cond_proposal_log_prob # This is log ( g(y'|y_t) / g(y_t|y') ) # NOTE: GSP payments are with respect to this. 
            gsp_payments_advertisers = []
            for i in range(len(advertiser_bids)):
                # make a copy of the advertiser bids 
                advertiser_bids_copy = advertiser_bids.copy()
                # remove the current advertiser bid from the copy
                advertiser_bids_copy[i] = 0

                # calculate the sum of the remaining bids
                sum_of_remaining_bids = np.sum(advertiser_bids_copy)

                # calculate the GSP payment for the current advertiser
                gsp_payment = tau * log_u_variable + tau *  proposal_distribution_log - sum_of_remaining_bids - reference_llm_bid # NOTE: The reference_llm_bid is already tau weighted
                gsp_payments_advertisers.append(gsp_payment)

            gsp_payment_reference_llm = tau * log_u_variable + tau * proposal_distribution_log - np.sum(advertiser_bids) 
        
        else: 
            # if the sequence does not get accepted. 
            # No one pays in GSP payments, as the critical bids were not met
            # And advertisers have already paied up to the current sequence 
            gsp_payments_advertisers = [0] * len(advertiser_bids)
            gsp_payment_reference_llm = 0

        gsp_payments_advertisers = np.array(gsp_payments_advertisers)

        # set_trace()

        return myerson_payments_advertisers, myerson_payment_reference_llm, myerson_case_counter, gsp_payments_advertisers, gsp_payment_reference_llm

    
    def sample(self, n_samples):
        """
        Run the Metropolis-Hastings chain and collect the sequences it visits.

        The returned list starts with the sequence the sampler currently holds and is
        followed by the result of each of the ``n_samples - 1`` subsequent steps.

        :param n_samples: Number of samples to generate (the starting sequence counts as one).
        :return: List of samples.
        """
        chain = [self.current_sequence]
        chain.extend(self.step() for _ in range(n_samples - 1))
        return chain
    

class target_probability: 
    def __init__(self, reference_LLM, reference_tokenizer, user_prompt, 
                advertiser_prompts, advertiser_cardinal_bids, advertiser_names, proposal_expansion, device,
                tau = 1, remove_start_token = False):
        """
        Set up the target probability distribution used by the Metropolis-Hastings sampler.

        Every prompt is stored both as a string and as a tokenizer encoding that has
        been moved to ``device``.

        reference_LLM: Language model to use as reference distribution
        reference_tokenizer: Tokenizer for reference_LLM
        user_prompt: Original user prompt
        advertiser_prompts: List of prompt prefixes, one per advertiser; the user prompt is appended to each
        advertiser_cardinal_bids: List of cardinal bids for each advertiser (stored as a numpy array)
        advertiser_names: List of the names of the advertisers e.g. ['KitchenFix', 'Easybake']
        proposal_expansion: Prefix for the "proposal" distribution prompt (e.g. "Answer the following query,
            trying to mention..."); the user prompt is appended to it. If None, no proposal prompt is stored.
        device: Device the encoded prompts are moved to
        tau: Balances between advertiser and user preferences (tau = 1 means equal weight)
        remove_start_token: Stored on the instance and forwarded to ``calculate_sequence_probability``
        """
        def encode_on_device(text):
            # Tokenize to a PyTorch tensor and place it on the configured device.
            return reference_tokenizer.encode(text, return_tensors="pt").to(device)

        self.reference_LLM = reference_LLM
        self.reference_tokenizer = reference_tokenizer
        self.device = device
        self.tau = tau
        self.remove_start_token = remove_start_token
        self.advertiser_names = advertiser_names
        self.user_prompt = user_prompt

        # String form of each advertiser's full prompt: advertiser prefix followed by the user prompt.
        full_advertiser_prompts = []
        for advertiser_prefix in advertiser_prompts:
            full_advertiser_prompts.append(advertiser_prefix + user_prompt)
        self.advertiser_prompts = full_advertiser_prompts
        self.advertiser_cardinal_bids = np.array(advertiser_cardinal_bids)

        # Encoded forms, built in a separate pass over the advertiser prefixes.
        self.user_prompt_encoded = encode_on_device(user_prompt)
        encoded_advertiser_prompts = []
        for advertiser_prefix in advertiser_prompts:
            encoded_advertiser_prompts.append(encode_on_device(advertiser_prefix + user_prompt))
        self.advertiser_prompts_encoded = encoded_advertiser_prompts

        # The proposal prompt is optional.
        if proposal_expansion is None:
            self.proposal_prompt = None
            self.proposal_prompt_encoded = None
        else:
            self.proposal_prompt = proposal_expansion + user_prompt
            self.proposal_prompt_encoded = encode_on_device(self.proposal_prompt)
        



    def evaluate(self, sequence, current_sequence, encoded_sequence):
        """
        Compute the probability density of the target distribution.
        Returns various probabilities metrics as a dictionary, that can be used by the MH sampler.

        :param sequence: Sequence to evaluate.
        :param current_sequence: Current sequence in the MH sampler. Needed to calculate g(y'|y) and g(y|y')
        :param encoded_sequence (bool): If True, the sequence is already encoded. If False, the sequence is a string and needs to be encoded.
        :param return_log_prob: If True, return the log probability instead of the probability.
        :param normalization_strategy: Normalization strategy to use. Can be one of "number_of_tokens", "number_of_bytes", "number_of_characters", or "None".
        :param return_individual_advertiser_log_probs: If True, return the log probability of the sequence under each advertiser's LLM.
        """
       
        # Calculate probability of sequence under reference LLM
        if encoded_sequence: 
            input_sequence = self.user_prompt_encoded
            advertiser_prompts = self.advertiser_prompts_encoded
            # remove the pudding from the sequence
            sequence = sequence[sequence != 0]
        else:
            input_sequence = self.user_prompt
            advertiser_prompts = self.advertiser_prompts

        # set_trace()
        reference_llm_prob, reference_llm_log_prob = calculate_sequence_probability(model = self.reference_LLM, tokenizer = self.reference_tokenizer, input_sequence= input_sequence, output_sequence= sequence, remove_start_token= self.remove_start_token, encoded_sequence= encoded_sequence, device= self.device)
        
        print('reference_llm_log_prob: ', reference_llm_log_prob)

        if self.proposal_prompt is not None:
            proposal_distribution_prob, proposal_distribution_log = calculate_sequence_probability(model = self.reference_LLM, tokenizer = self.reference_tokenizer, input_sequence= self.proposal_prompt_encoded, output_sequence= sequence, remove_start_token= self.remove_start_token, encoded_sequence= encoded_sequence, device= self.device)
                                  
        # Calculate probability of sequence under advertiser LLMs
        advertiser_log_probs = []
        for i in range(len(self.advertiser_prompts)):
            advertiser_prob, advertiser_log_prob = calculate_sequence_probability(model = self.reference_LLM, tokenizer = self.reference_tokenizer,
                input_sequence= advertiser_prompts[i], output_sequence= sequence, remove_start_token= self.remove_start_token, encoded_sequence= encoded_sequence, device = self.device)
            advertiser_log_probs.append(advertiser_log_prob)


        advertiser_log_probs = np.array(advertiser_log_probs)
        
        # Calculate welfare of advertisers
        advertiser_rewards_unweighted = advertiser_log_probs - reference_llm_log_prob  # According to our theory 
        advertiser_rewards_weighted = advertiser_rewards_unweighted * self.advertiser_cardinal_bids

        advertiser_welfare = np.sum(advertiser_rewards_weighted)

        print('advertiser log probs: ', advertiser_log_probs)



        # Calculate probability of sequence under target distribution
        # Implicit assumption: An advertiser's reward is the log probability of their LLM times their cardinal bid
        target_log_prob = reference_llm_log_prob  +  (advertiser_welfare / self.tau)
        target_prob = np.exp(target_log_prob)
        
        # Calculate all possible normalization factors, as well as other statistics for logging 
        advertiser_mentioned_dict = {}
        if not encoded_sequence:
            number_of_tokens = len(self.reference_tokenizer.encode(sequence, return_tensors='pt')[0]) 
            number_of_bytes = len(sequence.encode('utf-8'))
            number_of_characters = len(sequence)
            for i in range(len(self.advertiser_names)): 
                advertiser_mentioned_dict[f'advertiser {i} mentioned'] = 1 if self.advertiser_names[i] in sequence else 0
            advertiser_mentioned_dict['number_of_advertisers_mentioned'] = sum(advertiser_mentioned_dict.values())

        else: 
            number_of_tokens = len(sequence)
            decoded_sequence = self.reference_tokenizer.decode(sequence)
            number_of_bytes = len(decoded_sequence.encode('utf-8'))
            number_of_characters = len(decoded_sequence)
            for i in range(len(self.advertiser_names)):
                advertiser_mentioned_dict[f'advertiser {i} mentioned'] = 1 if self.advertiser_names[i] in decoded_sequence else 0
            advertiser_mentioned_dict['number_of_advertisers_mentioned'] = sum(advertiser_mentioned_dict.values())

        result_dict = {"target_log_prob": target_log_prob, "target_prob": target_prob, "reference_llm_log_prob": reference_llm_log_prob, "tau_weighted_reference_llm_log_prob": reference_llm_log_prob * self.tau,
                        "reference_llm_prob": reference_llm_prob, "advertiser_log_probs": advertiser_log_probs, 
                        "proposal_distribution_log_prob": proposal_distribution_log, "proposal_distribution_prob": proposal_distribution_prob, 
                        "advertiser_rewards_unweighted": advertiser_rewards_unweighted, "advertiser_rewards_weighted": advertiser_rewards_weighted,
                        "advertiser_welfare": advertiser_welfare,
                        "number_of_tokens": number_of_tokens, "number_of_bytes": number_of_bytes, "number_of_characters": number_of_characters, 'tau': self.tau,
                        }
        result_dict.update(advertiser_mentioned_dict)
        
        
        # Calculate g(y'|y) and g(y|y')
        # g(y'|y) is the probability of proposing y' given y
        # g(y|y') is the probability of proposing y given y'
        if current_sequence is not None:
            set_trace() # NOTE: MAKE SURE STUFF IS PROPERLY UNPADDED HERE! 
            proposal_cond_current_prob, proposal_cond_current_log_prob = calculate_sequence_probability(model = self.reference_LLM, tokenizer = self.reference_tokenizer,
                    input_sequence= self.proposal_prompt + current_sequence, output_sequence= sequence, remove_start_token= self.remove_start_token, encoded_sequence= encoded_sequence, device= self.device)
            
            current_cond_proposal_prob, current_cond_proposal_log_prob = calculate_sequence_probability(model = self.reference_LLM, tokenizer = self.reference_tokenizer,
                    input_sequence= self.proposal_prompt + sequence, output_sequence= current_sequence, remove_start_token= self.remove_start_token, encoded_sequence= encoded_sequence, device= self.device)
            
            result_dict["proposal_cond_current_log_prob"] = proposal_cond_current_log_prob
            result_dict["current_cond_proposal_log_prob"] = current_cond_proposal_log_prob
            result_dict["proposal_cond_current_prob"] = proposal_cond_current_prob
            result_dict["current_cond_proposal_prob"] = current_cond_proposal_prob


        return result_dict


    


if __name__ == "__main__":
    # Script entry point: builds a two-advertiser target distribution on top
    # of a seq2seq reference model, samples an initial reply, runs the
    # Metropolis-Hastings sampler from it and prints / logs the samples.
    wandb_logging = True
    number_of_samples = 100
    model_name = "google/flan-t5-large"

    # Query posed by the user
    user_input_sequence = "How do you make cookies?"

    # One prompt per advertiser, plus the prompt used for the proposal distribution
    advertisement_prompts = [
        "Answer the following prompt as if you were a creative advertiser for a company that makes kitchen applicances named KitchenFix: ",
        "Answer the following prompt as if you were a creative advertiser for a company that produces baking ingredients called EasyBake:",
    ]
    proposal_prompt = "Rephrase the following text, while trying to maintain any references to brands: "

    # Experiment configuration handed to Weights & Biases as the run config
    initialization_dict = dict(
        model_name=model_name,
        user_input_sequence=user_input_sequence,
        advertisement_prompts=advertisement_prompts,
        proposal_prompt=proposal_prompt,
    )
    if wandb_logging:
        wandb.init(project="LLM-MH-v0.2-TEST", config=initialization_dict)

    # Reference model and its tokenizer (flan)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Target probability distribution: both advertisers bid 1, tau is 1
    target_distribution = target_probability(
        reference_LLM=model,
        reference_tokenizer=tokenizer,
        user_prompt=user_input_sequence,
        advertiser_prompts=advertisement_prompts,
        advertiser_cardinal_bids=[1, 1],
        proposal_prompt=proposal_prompt,
        tau=1,
        remove_start_token=False,
    )

    # Earlier way of creating the initial value for MH (kept for reference):
    # input_expansion = "Answer the following prompt, while trying to mention KitchenFix, who makes kitchen applicances, and EasyBake, who produces baking ingredients: "
    # original_output_sequences = sample_sentences(model = model, tokenizer = tokenizer, input_sequence = input_expansion + user_input_sequence,
    #                                     max_length = 100, top_k=50,top_p=0.95,num_return_sequences=1, print_output= False)
    # original_output_sequence = tokenizer.decode(original_output_sequences[0], skip_special_tokens=True)

    # ---- Current way of creating the initial sequence ----
    # NOTE(review): the user query and the instruction below are concatenated
    # without a separating space ("...cookies?Answer the query. ...") -- confirm
    # this is intended.
    input_expansion = "Answer the query. Mention KitchenFix, who makes kitchen applicances, and EasyBake, who produces baking ingredients. "
    initial_prompt_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        input_sequence=user_input_sequence + input_expansion,
        max_length=100,
        top_k=10,
        top_p=0.95,
        num_return_sequences=1,
        print_output=False,
    )
    original_output_sequences_encoded = sample_sentences(**initial_prompt_kwargs)
    # Only one sequence is requested; decode it back to plain text
    first_encoded_reply = original_output_sequences_encoded[0]
    original_output_sequence = tokenizer.decode(first_encoded_reply, skip_special_tokens=True)

    # # print all replies
    # for i in range(len(original_output_sequences)):
    #     print('-'*100)
    #     print(f'Sample {i}: ', tokenizer.decode(original_output_sequences[i], skip_special_tokens=True))
    # set_trace()

    # Show the sequence the chain starts from
    separator = '-' * 100
    print(separator)
    print(f"Original sequence: {original_output_sequence}")

    # Metropolis-Hastings sampler started at the initial sequence
    sampler = LLM_MetropolisHastings(
        reference_LLM=model,
        reference_tokenizer=tokenizer,
        initial_sequence=original_output_sequence,
        proposal_prompt=proposal_prompt,
        target_distribution=target_distribution,
        wandb_logging=wandb_logging,
    )

    # Draw the requested number of samples
    samples = sampler.sample(number_of_samples)

    # Report every sample
    print(separator)
    print("Samples: ")
    for sample_index, sampled_sequence in enumerate(samples):
        print(f'Sample {sample_index}:', sampled_sequence)

    # Store the samples in the run summary and close the run
    if wandb_logging:
        wandb.run.summary["Sequences per round"] = samples
        wandb.finish()



    # # Combining original input sequence with advertisement prompt
    # input_sequence = advertisement_prompts[0] + user_input_sequence
    # input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
    # original_output_sequence = sample_sentences(model = model, tokenizer = tokenizer, input_sequence = input_sequence,
    #                                     max_length = 100, top_k=50,top_p=0.95,num_return_sequences=1, print_output= False)
    # print('-'*50)
    # print("Ad1 sequence: {}".format(tokenizer.decode(original_output_sequence[0], skip_special_tokens=True)))


    # input_sequence = advertisement_prompts[1] + user_input_sequence
    # input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
    # original_output_sequence = sample_sentences(model = model, tokenizer = tokenizer, input_sequence = input_sequence,
    #                                     max_length = 100, top_k=50,top_p=0.95,num_return_sequences=1, print_output= False)
    # print('-'*50)
    # print("Ad2 sequence: {}".format(tokenizer.decode(original_output_sequence[0], skip_special_tokens=True)))



# --- Parameters to log at the beginning of the experiment --- #
# 1. Initial user prompt
# 2. Target distribution (i.e., advertiser LLMs + cardinal bids)
# 3. Proposal distribution (i.e., reference LLM proposal prompt)
# 4. Normalization strategy
# 5. Number of rounds / stopping criteria
# 6. LLM architecture
# 7. LLM parameters
# 8. Prompt used to generate initial sequence



# --- Things to log at every state --- # 
# 1. Current sequence + Proposal sequence
# 2. Current sequence (+ proposal sequence) log probability for target distribution
# 3. Current sequence (+ proposal sequence) log probability for proposal distribution
# 4. Current sequence (+ proposal sequence) log probability for reference LLM
# 5. Current sequence (+ proposal sequence) log probability for advertiser LLMs (scaled by cardinal bids AND unscaled)
# 6. All above values normalized by number of tokens in sequence
# 7. Acceptance probability
# 8. Acceptance decision
# 9. Number of tokens in sequence
# 10. Number of bytes in sequence
# 11. Number of characters in sequence
# 12. GSP payments in each round (+ also with offset)
# 13. Myerson payments in each round
# 14. Cumulative payments for each advertiser up to each round. 
# 15. Number of times each advertiser is mentioned
# 16. Number of times each advertiser is mentioned in each round
# 17. Number of times each advertiser is mentioned in each round, normalized by number of tokens in sequence
# 18. Time taken for each round
# 19. Time taken for each round to generate proposal sequence
# 20. Time taken for each round to evaluate proposal sequence
# 21. Reward gain per agent
# 22. Expected utility gain per agent (i.e., reward gain - payment)



# {'round': 1, 'current_sequence': 'EasyBake produces cookie recipes, containing baking powder, sugar, butter, and vanilla.', 
#  'proposal sequence': 'EasyBake makes baking powder, sugar, butter, and vanilla cookies.', 
#  'current probability': 1.996411692887478e-246, 'current log probability': -565.7445814609528, 
#  'proposal probability': 1.1853935861232035e-174, 'proposal log probability': -400.47973132133484, 
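#  NOTE: in this recorded sample the two values on the next line appear swapped between their keys
#  (exp(-129.40) ~= 6.33e-57, so -129.40 is the log probability and 6.33e-57 the probability).
#  'proposal reference llm probability': -129.40221095085144, 'proposal reference llm log probability': 6.328981410705241e-57, 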
#  'advertiser welfare': -271.0775203704834, 'number of tokens': 17, 'number of bytes': 65, 'number of characters': 65, 
#  'acceptance probability': 1, 'accepted proposal': 1, 
#  'myerson payments advertiser 0': 0, 'gsp payments advertiser 0': -165.03156432611576, 
#  'indirect utility MP advertiser 0': 59.0269992351532, 'indirect utility GSP advertiser 0': 224.05856356126895, 
#  'myerson payments advertiser 1': 0, 'gsp payments advertiser 1': -160.07774338228336, 
#  'indirect utility MP advertiser 1': 52.16467261314392, 'indirect utility GSP advertiser 1': 212.24241599542728, 
#  'myerson payments reference llm': 0, 'gsp payments reference llm': -165.03156432611576, 
#  'indirect utility MP reference llm': 54.0731782913208, 'indirect utility GSP reference llm': 219.10474261743656, 
#  'total indirect utility MP': 165.26485013961792, 'total indirect utility GSP': 655.4057221741327, 
#  'total payments MP': 0, 'total payments GSP': -490.14087203451487, 
#  'total payments MP advertisers': 0, 'total payments GSP advertisers': -325.1093077083991}